"""
Print the main catastrophe-risk pricing table (Table 3) and the accompanying Figure 1.
"""
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import statsmodels.formula.api as sm
from matplotlib.ticker import FuncFormatter

from cat_rfs.code.utils.table import Table
from cat_rfs.code.utils.tools import save_stat

def _make_setting(beta_var, suffix, fm_name, yield_var='sheet_yield_usd', wsst=False):
    """Build one SETTINGS entry.

    Output file stems are ``'fig1' + suffix`` and ``'table3' + suffix``; every
    specification uses ``'sheet_yield_usd'`` and no WSST adjustment unless told otherwise.
    """
    return {'beta_var': beta_var,
            'yield_var': yield_var,
            'wsst': wsst,
            'fig_file': 'fig1' + suffix,
            'table_file': 'table3' + suffix,
            'fm_name': fm_name}


# Specifications understood by print_table3, keyed by setting name. Each entry gives the
# beta column, the yield column, whether to swap in 'expected_loss_wsst', the figure and
# table file stems, and the row label used for the Fama-MacBeth summary line.
SETTINGS = {
    'main': _make_setting('beta_sheet', '', 'Main'),
    'trace': _make_setting('beta_act', '_trace', 'Trace', yield_var='act_yield_usd'),
    'corr75': _make_setting('beta_corr75', '_corr75', 'Correlation = 0.75'),
    'corr50': _make_setting('beta_corr50', '_corr50', 'Correlation = 0.50'),
    'corr25': _make_setting('beta_corr25', '_corr25', 'Correlation = 0.25'),
    'loss': _make_setting('beta_loss', '_loss', 'Zero-yield'),
    'ew': _make_setting('beta_sheet_ew', '_ew', 'EW portfolio'),
    'random_peril': _make_setting('beta_random_peril', '_random_peril', 'Random perils'),
    'equal_el': _make_setting('beta_equal_el', '_equal_el', 'Equal $el$'),
    'equalize_all': _make_setting('beta_equalize_all', '_equalize_all', 'Full placebo'),
    'wsst': _make_setting('beta_sheet', '_wsst', 'WSST', wsst=True),
}


def _hac_mean(data, col, lags):
    """Return ``(mean, standard error)`` of ``data[col]`` from a regression on a constant.

    The standard error uses statsmodels' HAC covariance (``maxlags=lags``) with
    t-based inference (``use_t=True``) -- the Fama-MacBeth second stage.
    """
    fit = sm.ols(f'{col} ~ 1', data=data).fit(cov_type='HAC', cov_kwds={'maxlags': lags}, use_t=True)
    return fit.params['Intercept'], fit.bse['Intercept']


def _cross_section_regressions(df, beta):
    """Run the first-stage, per-date cross-sectional regressions of ``er`` on ``beta``.

    For every date in ``df`` (in order of first appearance) two OLS fits are
    estimated: with an intercept and through the origin. Returns
    ``(res_jun, res_all)``, lists of rows
    ``[year, alpha, lambda, lambda - er_m, R2, lambda0, lambda0 - er_m, R2_0, nobs]``
    where coefficients are multiplied by 100; ``res_jun`` keeps only quarter-2 dates.
    """
    res_jun = []
    res_all = []
    for t in df['date'].unique():
        cs = df[df['date'] == t].reset_index()
        # er_m was merged in per date, so it is a single value within the cross-section.
        er_m = cs['er_m'].unique()[0]
        ols = sm.ols(formula=f'er ~ {beta}', data=cs).fit()
        ols0 = sm.ols(formula=f'er ~ -1 + {beta}', data=cs).fit()

        stamp = pd.to_datetime(t)
        row = [stamp.year,
               ols.params['Intercept'] * 100,
               ols.params[beta] * 100,
               (ols.params[beta] - er_m) * 100,
               ols.rsquared,
               ols0.params[beta] * 100,
               (ols0.params[beta] - er_m) * 100,
               ols0.rsquared,
               ols.nobs]
        if stamp.quarter == 2:
            res_jun.append(list(row))
        res_all.append(row)
    return res_jun, res_all


def _plot_fig1(df, output, setting, fname):
    """Scatter observed vs. model-predicted premia for quarter-2 dates (Figure 1).

    Computes a centred R^2 of ``er`` on ``er_model``; for the 'main' setting it is
    stored via ``save_stat``. With ``output == 'paper'`` the figure is saved as EPS
    and closed, otherwise it is shown.
    """
    df = df[df['date'].dt.quarter == 2].copy()  # Jun only
    df['error'] = df['er'] - df['er_model']

    # Demeaned errors and premia, so the R^2 below is the centred one.
    df['errorm'] = df['error'] - df['error'].mean()
    df['erm'] = df['er'] - df['er'].mean()

    if df[['error', 'er', 'errorm', 'erm']].isnull().any().any():
        raise ValueError('Figure 1 inputs contain missing values (er / er_model).')
    r2 = 1 - (df['errorm'] ** 2).sum() / (df['erm'] ** 2).sum()

    if setting == 'main':
        save_stat(r2 * 100, 'r2', formatting='%.0f', output=output)

    sns.set_palette('Greys_r')
    cp = sns.color_palette()

    with sns.color_palette([cp[0], cp[2], cp[4]]):
        ax = df.plot.scatter(x='er_model', y='er', s=5)

        ax.set_ylim([0, 0.21])
        ax.set_xlim([0, 0.21])
        ax.yaxis.set_major_formatter(FuncFormatter('{0:.0%}'.format))
        ax.xaxis.set_major_formatter(FuncFormatter('{0:.0%}'.format))

        ax.set_ylabel('Observed premium\n\n' + r'$E_t\left(R^e_{i}\right)$')
        ax.set_xlabel('Predicted premium\n' + r'$\hat{\beta}_{i,t}E_t\left(R^e_{cat}\right)$', ma='center')

        ax.text(0.0052, 0.19, f'$R^2=$ {r2:.2f}')

        # 45-degree line: both axes share the same limits.
        ax.plot(ax.get_xlim(), ax.get_ylim(), ls='--', c='.3')

        if output == 'paper':
            plt.savefig(f'cat_rfs/output/figures/{fname}.eps')
            plt.close()
        else:
            plt.show()


def print_table3(df, output='console', setting='main'):
    """Estimate Fama-MacBeth pricing of catastrophe market risk; print Table 3 and Figure 1.

    Parameters
    ----------
    df : pandas.DataFrame
        Bond panel. Columns read: 'date' (datetime-like, ``.dt.quarter`` is used),
        'CUSIP9', 'rf', 'size', 'expected_loss' (replaced by 'expected_loss_wsst'
        when the setting's 'wsst' flag is set) and the setting's yield and beta
        columns (the beta name gets a '_wsst' suffix under 'wsst'). The
        'expected_loss' and 'rf' columns are divided by 100 before being
        subtracted from the yield -- presumably they are in percent; confirm.
        The caller's frame is not modified.
    output : str
        'paper' writes the full table, the FM-only table and the figure under
        ``cat_rfs/output/``; any other value prints the full table to the console
        and shows the figure (the FM-only table is written only for 'paper').
    setting : str
        Key into ``SETTINGS`` selecting the specification.
    """
    cfg = SETTINGS[setting]
    beta = cfg['beta_var']
    yield_ = cfg['yield_var']
    tname = cfg['table_file']
    fname = cfg['fig_file']

    # sort_values/reset_index return new frames, so the caller's df is untouched.
    df = df.sort_values(['date', 'CUSIP9']).reset_index(drop=True)

    if cfg['wsst']:
        df = df.drop(columns=['expected_loss']).rename(columns={'expected_loss_wsst': 'expected_loss'})
        beta += '_wsst'

    if yield_ == 'act_yield_usd':
        df = df[(df['date'] >= '12/31/2004') & (df['date'] <= '12/31/2018')].copy()

    df['er'] = df[yield_] - df['expected_loss'] / 100 - df['rf'] / 100
    df = df.dropna(subset=[beta, 'er'])

    # Size-weighted cross-sectional average excess return per date ("market" premium).
    df['er_m'] = df['er'] * df['size']
    mkt = df.groupby('date', as_index=False)[['er_m', 'size']].sum()
    mkt['er_m'] = mkt['er_m'] / mkt['size']
    df = df.drop(columns=['er_m']).merge(mkt[['date', 'er_m']], on='date')

    df['er_model'] = df[beta] * df['er_m']

    # ---- First stage: per-date cross-sectional regressions ----
    res_jun, res_all = _cross_section_regressions(df, beta)

    res_df_jun = pd.DataFrame(res_jun, columns=['year', 'alpha', 'lambda', 'lambda-Erm', 'Rsquared', 'lambda0',
                                                'lambda0-Erm', 'Rsquared0', 'N'])
    for c in ['year', 'N']:
        res_df_jun[c] = res_df_jun[c].astype(int).astype(str)

    T = Table(9, width='auto', justs=['l'] + ['d'] * 8,
              caption='Pricing of catastrophe market risk',
              label=tname, footer=None)

    T.add_panel(res_df_jun, formatting="%.2f",
                supheader_cols=[1, (2, 5), (6, 8), 9], supheaders=['', '(1)', '(2)', ''],
                headers=['$t$', '$\\lambda_{0,t}$', '$\\lambda_{cat,t}$',
                         '$\\lambda^-_{cat,t}$', '$R^2$', '$\\lambda^0_{cat,t}$',
                         '$\\lambda^{0,-}_{cat,t}$', '$R^2$', '$N$'])

    # ---- Second stage: time-series means with HAC standard errors ----
    # Column names avoid 'lambda' / '-' because they are used inside formulas.
    res_df_all = pd.DataFrame(res_all, columns=['year', 'alpha', 'lambda_', 'lambda_Erm', 'Rsquared', 'lambda0',
                                                'lambda0_Erm', 'Rsquared0', 'N'])

    lags = 4
    a, a_se = _hac_mean(res_df_all, 'alpha', lags)
    la, la_se = _hac_mean(res_df_all, 'lambda_', lags)
    lam, lam_se = _hac_mean(res_df_all, 'lambda_Erm', lags)
    la0, la0_se = _hac_mean(res_df_all, 'lambda0', lags)
    lam0, lam0_se = _hac_mean(res_df_all, 'lambda0_Erm', lags)

    N = res_df_all.shape[0]
    R2 = res_df_all['Rsquared'].mean()
    R20 = res_df_all['Rsquared0'].mean()

    data2 = pd.DataFrame([['FM', a, la, lam, R2, la0, lam0, R20, N],
                          ['', a_se, la_se, lam_se, '', la0_se, lam0_se, '', '']])

    T.add_panel(data2, formatting="%.2f", regression_specs={'stars': (0.1, 0.05, 0.01), 'dof': (N - 1),
                                                            'skip_cols': [5, 8, 9]})

    if output == 'paper':
        T.print_table(f'cat_rfs/output/tables/{tname}.tex')
    else:
        T.print_table()

    # FM-only summary row, labelled with the specification name.
    data2 = pd.DataFrame([[cfg['fm_name'], a, la, lam, R2, la0, lam0, R20, N],
                          ['', a_se, la_se, lam_se, '', la0_se, lam0_se, '', '']])

    T2 = Table(9, data_only=True, justs=['l'] + ['d'] * 8)
    T2.add_panel(data2, formatting="%.2f", regression_specs={'stars': (0.1, 0.05, 0.01), 'dof': (N - 1),
                                                             'skip_cols': [5, 8, 9]})

    if output == 'paper':
        # Previously printed T (the full table) here by mistake; the FM-only file needs T2.
        T2.print_table(f'cat_rfs/output/tables/fm_only/{tname}_fm_only.tex')

    # ---- Figure 1 ----
    _plot_fig1(df, output, setting, fname)
